import re
import numpy as np
from functools import lru_cache
import ast
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel

# SciBERT tokenizer and encoder, loaded once when this module is imported.
# Both objects are shared module-level state used by get_embedding() below.
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

def get_funcs_jaccard_sim(pseudocode_1, pseudocode_2):
    """Return the Jaccard similarity of the function names used by two snippets.

    Both pseudocode strings are parsed with ``parse_pseudocode``; only the keys
    of the resulting dictionaries (the index-suffixed function names) are
    compared, the arguments are ignored.

    Args:
        pseudocode_1: First pseudocode string.
        pseudocode_2: Second pseudocode string.

    Returns:
        Whatever ``calculate_jaccard`` returns for the two name sets.
    """
    names_1 = set(parse_pseudocode(pseudocode_1))
    names_2 = set(parse_pseudocode(pseudocode_2))
    return calculate_jaccard(names_1, names_2)

def get_argkeys_jaccard_sim(pseudocode_1, pseudocode_2):
    """Return the mean Jaccard similarity of argument keys for shared functions.

    Only functions whose (index-suffixed) name appears in both parsed snippets
    are considered, and a shared function contributes a score only when it has
    at least one parsed argument on both sides.

    Args:
        pseudocode_1: First pseudocode string.
        pseudocode_2: Second pseudocode string.

    Returns:
        The mean of the per-function Jaccard scores as a Python float, or the
        integer ``0`` when no function contributes a score.
    """
    parsed_1 = parse_pseudocode(pseudocode_1)
    parsed_2 = parse_pseudocode(pseudocode_2)

    shared_names = set(parsed_1) & set(parsed_2)

    # Compare argument *keys* only; iterating a dict inside set() drops values.
    scores = [
        calculate_jaccard(set(parsed_1[name]), set(parsed_2[name]))
        for name in shared_names
        if parsed_1[name] and parsed_2[name]
    ]

    return np.mean(scores).item() if scores else 0

def get_funcs_scibert_sim(pseudocode_1, pseudocode_2):
    """Return the mean pairwise cosine similarity between function-name embeddings.

    Every function name parsed from the first snippet is embedded with
    ``get_embedding`` and compared against every function name parsed from the
    second snippet (a full cross product, not a one-to-one matching).

    Args:
        pseudocode_1: First pseudocode string.
        pseudocode_2: Second pseudocode string.

    Returns:
        The mean of all pairwise cosine similarities (as returned by
        ``np.mean``), or ``0.0`` when either snippet yields no function names.
    """
    names_1 = parse_pseudocode(pseudocode_1)
    names_2 = parse_pseudocode(pseudocode_2)

    vectors_1 = [torch.tensor(get_embedding(name)) for name in names_1]
    vectors_2 = [torch.tensor(get_embedding(name)) for name in names_2]

    pairwise_scores = [
        F.cosine_similarity(left, right, dim=0).item()
        for left in vectors_1
        for right in vectors_2
    ]

    if not pairwise_scores:
        return 0.0
    return np.mean(pairwise_scores)


def parse_pseudocode(code):
    """Extract function calls and their keyword arguments from pseudocode.

    A call is any ``name(...)`` span whose parentheses contain at least one
    character and no closing parenthesis. Within the parentheses, ``key="value"``
    / ``key='value'`` pairs and ``key=bareword`` pairs are collected; when a key
    is matched by both patterns, the quoted value wins.

    Each call is stored under its name suffixed with a zero-based occurrence
    index (``click0``, ``click1``, ...), so repeated calls to the same function
    are all kept instead of overwriting one another.

    Args:
        code: Pseudocode text to scan.

    Returns:
        A dict mapping ``"<name><index>"`` to a dict of ``{arg_key: arg_value}``
        (all values are strings), in order of appearance.
    """
    calls = re.findall(r"(\w+)\s*\(([^)]+)\)", code)

    occurrence_count = {}
    function_dict = {}

    for function_name, raw_arguments in calls:
        quoted_arguments = re.findall(r'(\w+)=["\']([^"\']+)["\']', raw_arguments)
        bare_arguments = re.findall(r'(\w+)=(\d+|[\w]+)', raw_arguments)

        argument_dict = dict(quoted_arguments)
        for key, value in bare_arguments:
            # Quoted values take precedence over bare-word matches.
            argument_dict.setdefault(key, value)

        # Index per occurrence here, before anything is keyed by the bare name;
        # keying by the bare name first would silently drop repeated calls.
        index = occurrence_count.get(function_name, 0)
        occurrence_count[function_name] = index + 1
        function_dict[function_name + str(index)] = argument_dict

    return function_dict

def append_index_to_function_names(func_dict):
    """Return a copy of ``func_dict`` whose keys carry an occurrence index suffix.

    Each key is renamed to ``key + str(n)``, where ``n`` is the number of times
    that key has already been seen while iterating ``func_dict.items()``. For a
    plain ``dict`` every key is unique, so every suffix is ``"0"``.

    Args:
        func_dict: Mapping of function name to its arguments.

    Returns:
        A new dict with the suffixed names as keys and the original argument
        objects (not copies) as values.
    """
    seen = {}
    renamed = {}

    for name, arguments in func_dict.items():
        # First sighting -> 0, every further sighting bumps the counter by one.
        seen[name] = seen.get(name, -1) + 1
        renamed[name + str(seen[name])] = arguments

    return renamed

def calculate_jaccard(set1: set, set2: set):
    """Return the Jaccard similarity ``|set1 & set2| / |set1 | set2|``.

    Args:
        set1: First set.
        set2: Second set.

    Returns:
        A float in ``[0, 1]``. When both sets are empty the ratio is undefined;
        ``0.0`` is returned instead of raising ``ZeroDivisionError``, matching
        the "nothing to compare" fallbacks used by the other metrics in this
        module.
    """
    union = set1.union(set2)
    if not union:
        return 0.0
    return len(set1.intersection(set2)) / len(union)

def parse_python_code(code):  # Code is a multi-line string.
    """Split Python source into top-level function definitions and the rest.

    A function starts at a line matching ``def name(...):`` at column 0 (the
    whole signature must be on that one line and end in ``):``) and continues
    over every following line that begins with whitespace. Empty lines inside a
    function body are kept with the function; empty lines that are followed by
    top-level code, another ``def``, or the end of the input are treated as
    part of the main code.

    Args:
        code: Python source as a single string with ``"\\n"`` line separators.

    Returns:
        A tuple ``(funcs, main_code)`` where ``funcs`` is a list with the
        source text of each detected function and ``main_code`` is the
        remaining lines joined with ``"\\n"``.
    """
    funcs = []  # Source text of each detected function
    main_code = []  # Lines that belong to no function
    func = []  # Lines of the function currently being collected
    pending_blank = []  # Empty lines seen in a function, owner not yet known
    in_func = False

    for line in code.split("\n"):
        if re.match(r"def [\w_]+\(.*?\):", line):  # start of a function
            if func:  # a previous function was still open: close it
                funcs.append("\n".join(func))
            # Empty lines just before a new def sit between definitions.
            main_code.extend(pending_blank)
            pending_blank = []
            func = [line]
            in_func = True
            continue

        if in_func:
            if line == "":
                # An empty line alone does not end the body; defer the decision
                # until the next non-empty line shows who it belongs to.
                pending_blank.append(line)
                continue
            if re.match(r"\s", line):  # still inside the function body
                func.extend(pending_blank)
                pending_blank = []
                func.append(line)
                continue
            # A non-indented line ends the function.
            in_func = False
            funcs.append("\n".join(func))
            func = []
            main_code.extend(pending_blank)
            pending_blank = []

        main_code.append(line)

    if func:  # input ended while a function was still open
        funcs.append("\n".join(func))
    main_code.extend(pending_blank)

    return funcs, "\n".join(main_code)

def lev_dist(a, b):
    """
    This function will calculate the Levenshtein distance between two input
    strings a and b, i.e. the minimum number of single-character insertions,
    deletions and replacements needed to turn one into the other.

    The computation is an iterative dynamic program that keeps only one row of
    the edit-distance table at a time, so it does not depend on the
    interpreter's recursion limit and works for long inputs.

    params:
        a (String) : The first string you want to compare
        b (String) : The second string you want to compare

    returns:
        This function will return the distance between string a and b as an
        int.

    example:
        a = 'stamp'
        b = 'stomp'
        lev_dist(a,b)
        >> 1
    """
    # previous[j] holds the distance between the processed prefix of `a` and
    # b[:j]; for the empty prefix of `a` that is j insertions.
    previous = list(range(len(b) + 1))

    for i, char_a in enumerate(a, start=1):
        current = [i]  # distance from a[:i] to the empty string: i deletions
        for j, char_b in enumerate(b, start=1):
            if char_a == char_b:
                # no change required
                current.append(previous[j - 1])
            else:
                current.append(
                    1
                    + min(
                        current[j - 1],  # insert character
                        previous[j],  # delete character
                        previous[j - 1],  # replace character
                    )
                )
        previous = current

    return previous[-1]

def get_levenshtein_distance(pseudofunctions, pseudocode_1, pseudocode_2, penalty_weight=1):
    """Score a predicted program against a ground truth by its call sequence.

    Each program is reduced to the sequence of function calls found by
    ``ast.walk``; every known function is mapped to its own single character,
    and the Levenshtein distance between the two resulting strings is
    computed. Calls to anything that is not one of the defined pseudo-functions
    (including calls whose callee is not a plain name, such as
    ``obj.method()``) are mapped to a shared "undefined function" symbol.

    Args:
        pseudofunctions: Python source containing the ``def`` statements of the
            allowed pseudo-functions.
        pseudocode_1: Ground-truth program (Python source).
        pseudocode_2: Predicted program (Python source).
        penalty_weight: Weight applied to the number of undefined calls in the
            prediction. Defaults to 1.

    Returns:
        ``lev_dist(prediction, ground_truth) + undefined_calls_in_prediction *
        penalty_weight``. Undefined calls in the ground truth add no penalty.

    Raises:
        SyntaxError: If any of the inputs cannot be parsed by ``ast.parse``.
    """
    undefined_function_name = "undefined_function"

    # Step 1: Get the names of the pseudo-functions, plus a catch-all entry
    # that stands in for every call to something undefined.
    function_trees = [ast.parse(src) for src in parse_python_code(pseudofunctions)[0]]
    function_names = [
        node.name
        for tree in function_trees
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef)
    ]
    function_names.append(undefined_function_name)

    # Step 2: Map each function name to a distinct character. Only equality of
    # characters matters to the edit distance, so consecutive code points are
    # used; this avoids running out of symbols with a fixed alphabet.
    symbol_for = {name: chr(index) for index, name in enumerate(function_names)}

    def map_code(pseudocode):
        """Return (call-sequence string, number of undefined calls)."""
        penalty = 0  # Number of undefined functions used
        symbols = []  # One character per call (no arguments).
        # TODO: Handle if statements
        # NOTE(review): ast.walk does not promise source order, so nested
        # calls may not appear in the order they are written — confirm this is
        # acceptable for the metric.
        for node in ast.walk(ast.parse(pseudocode)):
            if not isinstance(node, ast.Call):
                continue
            # Only plain-name callees carry `.id`; anything else (attribute
            # access, call results, ...) is counted as undefined.
            called_name = node.func.id if isinstance(node.func, ast.Name) else None
            if called_name in symbol_for:
                symbols.append(symbol_for[called_name])
            else:
                penalty += 1
                symbols.append(symbol_for[undefined_function_name])
        return "".join(symbols), penalty

    prediction_string, prediction_penalty = map_code(pseudocode_2)
    ground_truth_string, _ = map_code(pseudocode_1)

    return (
        lev_dist(prediction_string, ground_truth_string)
        + prediction_penalty * penalty_weight
    )

@lru_cache(maxsize=1024)
def get_embedding(text):
    """Return the L2-normalised embedding of the CLS token for ``text``.

    The text is tokenised (truncated to at most 512 tokens) and passed through
    the module-level ``model`` with gradients disabled. The hidden state at
    sequence position 0 is normalised to unit L2 norm and returned as a NumPy
    array with size-1 dimensions squeezed out.

    Results are memoised per ``text`` (up to 1024 entries), so repeated calls
    return the same array object; callers should not modify it in place.
    """
    encoded = tokenizer(text, return_tensors='pt', max_length=512, truncation=True)
    with torch.no_grad():
        model_output = model(**encoded)
    cls_state = model_output.last_hidden_state[:, 0, :]
    unit_vector = F.normalize(cls_state, p=2, dim=1)
    return unit_vector.squeeze().numpy()